import sys
import torch
import torch.nn as nn
import torchtext
from torchtext.vocab import Vectors
import numpy as np
import random

# Global configuration and data pipeline for the text8 language-model script.
# Everything below runs at import time: seeding, corpus loading, vocabulary
# construction and batch-iterator creation.

# Single seed shared by every random number generator seeded below.
SEED = 1234

USE_CUDA = torch.cuda.is_available()

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if USE_CUDA:
    torch.cuda.manual_seed(SEED)

BATCH_SIZE = 32
EMBEDDING_SIZE = 100
HIDDEN_SIZE = 100
MAX_VOCAB_SIZE = 50000
NUM_EPOCHS = 2

# Text field that lower-cases the input.
TEXT = torchtext.data.Field(lower=True)

# NOTE(review): validation and test both point at the training file, so the
# "validation loss" used later is computed on training data -- confirm whether
# separate dev/test files were intended.
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(path="./data/text8.all/",
        train="text8.train.txt",
        validation="text8.train.txt",
        test="text8.train.txt", text_field=TEXT)

TEXT.build_vocab(train, max_size=MAX_VOCAB_SIZE) # build the vocabulary from the training split
print("vocabulary size: {}".format(len(TEXT.vocab))) # 50002 entries expected: <unk> and <pad> are added automatically on top of MAX_VOCAB_SIZE

print(TEXT.vocab.itos[:10]) # itos is a list; prints its first 10 entries, e.g. ['<unk>', '<pad>', 'the', 'of', 'and', 'one', 'in', 'a', 'to', 'zero']
print(TEXT.vocab.stoi["<unk>"]) # stoi maps a word to its integer id
print(TEXT.vocab.stoi["similar"])
VOCAB_SIZE = len(TEXT.vocab)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits(
    (train, val, test), batch_size=BATCH_SIZE, device=device,
    bptt_len=32, repeat=False, shuffle=True) # bptt_len: maximum number of steps to back-propagate through time

# Debugging aid (disabled): decode one column of a batch back into words.
# it = iter(train_iter)
# batch = next(it)
# print(" ".join([TEXT.vocab.itos[i] for i in batch.text[:,9].data])) # one sequence from the batch
# print(" ".join([TEXT.vocab.itos[i] for i in batch.target[:,9].data])) # its target: the next word for each position of that sequence


# for i in range(5):
#     batch = next(it)
#     print(" ".join([TEXT.vocab.itos[i] for i in batch.text[:,2].data]))
#     print(" ".join([TEXT.vocab.itos[i] for i in batch.target[:,2].data]))

# define model
class RNNModel(nn.Module):
    """A simple recurrent language model: embedding -> RNN -> linear decoder."""

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):
        """Build the layers of the model.

        Layers:
            - a dropout layer used for regularization
            - a word-embedding layer (``ntoken`` x ``ninp``)
            - a recurrent layer (LSTM, GRU, or plain RNN with tanh/relu)
            - a linear layer mapping the hidden state to vocabulary scores

        Args:
            rnn_type: one of 'LSTM', 'GRU', 'RNN_TANH', 'RNN_RELU'.
            ntoken: vocabulary size.
            ninp: embedding dimension.
            nhid: hidden-state dimension of the recurrent layer.
            nlayers: number of stacked recurrent layers.
            dropout: dropout probability (also passed to the recurrent layer).

        Raises:
            ValueError: if ``rnn_type`` is not one of the supported names.
        """
        super().__init__()

        # Sub-modules are registered in the order drop, encoder, rnn, decoder.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)

        if rnn_type in ('LSTM', 'GRU'):
            rnn_cls = getattr(nn, rnn_type)
            self.rnn = rnn_cls(ninp, nhid, nlayers, dropout=dropout)
        else:
            activations = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}
            try:
                nonlinearity = activations[rnn_type]
            except KeyError:
                raise ValueError( """An invalid option for `--model` was supplied,
                                 options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""")
            self.rnn = nn.RNN(
                ninp, nhid, nlayers,
                nonlinearity=nonlinearity,
                dropout=dropout,
            )

        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        """Re-initialize embedding/decoder weights uniformly in [-0.1, 0.1]
        and zero the decoder bias."""
        bound = 0.1
        with torch.no_grad():
            self.encoder.weight.uniform_(-bound, bound)
            self.decoder.bias.zero_()
            self.decoder.weight.uniform_(-bound, bound)

    def forward(self, input, hidden):
        """Run one forward pass.

        Steps:
            - embed the input token ids (with dropout)
            - feed the embeddings through the recurrent layer
            - apply dropout and project every position onto the vocabulary

        Returns:
            A pair ``(scores, hidden)`` where ``scores`` keeps the first two
            dimensions of the recurrent output and has the decoder's output
            size as its last dimension, and ``hidden`` is the new recurrent
            state.
        """
        embedded = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)

        # Flatten the two leading dimensions so the decoder sees a 2-D
        # matrix, then restore them on the way out.
        dim0, dim1, feat = rnn_out.size(0), rnn_out.size(1), rnn_out.size(2)
        flat_scores = self.decoder(rnn_out.view(dim0 * dim1, feat))
        scores = flat_scores.view(dim0, dim1, flat_scores.size(1))
        return scores, hidden

    def init_hidden(self, bsz, requires_grad=True):
        """Return a zero initial state for batch size ``bsz``.

        The tensors are created from the first model parameter, so they share
        its dtype and device. An LSTM gets a pair of tensors; every other
        recurrent type gets a single tensor.
        """
        reference = next(self.parameters())
        shape = (self.nlayers, bsz, self.nhid)

        def zeros():
            return reference.new_zeros(shape, requires_grad=requires_grad)

        if self.rnn_type != 'LSTM':
            return zeros()
        return (zeros(), zeros())

# Two-layer LSTM language model. The hidden width reuses EMBEDDING_SIZE; the
# HIDDEN_SIZE constant is not consulted here.
model = RNNModel(
    rnn_type="LSTM",
    ntoken=VOCAB_SIZE,
    ninp=EMBEDDING_SIZE,
    nhid=EMBEDDING_SIZE,
    nlayers=2,
    dropout=0.5,
)
if USE_CUDA:
    model = model.cuda()


def evaluate(model, data):
    """Return the average per-element loss of ``model`` over ``data``.

    Args:
        model: a model exposing ``init_hidden(bsz, requires_grad=...)`` and
            returning ``(output, hidden)`` when called with
            ``(input, hidden)``.
        data: an iterable of batches that have ``.text`` and ``.target``
            tensors (presumably a torchtext BPTT iterator -- verify against
            the caller).

    Returns:
        The sum of each batch's mean loss weighted by the number of elements
        in that batch's ``text``, divided by the total number of elements.

    Raises:
        ZeroDivisionError: if ``data`` yields no batches.

    The model is switched to eval mode for the duration of the call and put
    back into whatever mode it was in beforehand, even if an error occurs.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.
    total_count = 0.
    try:
        # A single no_grad context covers the whole loop.
        with torch.no_grad():
            hidden = model.init_hidden(BATCH_SIZE, requires_grad=False)
            for batch in data:
                # Use a distinct local name so the ``data`` argument is not
                # overwritten while it is being iterated.
                text, target = batch.text, batch.target
                if USE_CUDA:
                    text, target = text.cuda(), target.cuda()
                hidden = repackage_hidden(hidden)
                output, hidden = model(text, hidden)
                loss = loss_fn(output.view(-1, VOCAB_SIZE), target.view(-1))
                # loss_fn averages over the batch, so weight it by the number
                # of elements to get a correct overall mean.
                n_elements = text.numel()
                total_count += n_elements
                total_loss += loss.item() * n_elements
    finally:
        model.train(was_training)

    return total_loss / total_count

def repackage_hidden(h):
    """Detach a hidden state from the graph that produced it.

    A tensor is returned detached; any other iterable is processed
    element-wise (recursively) and returned as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

# Training setup: cross-entropy loss, Adam, and an exponential schedule that
# halves the learning rate each time scheduler.step() is called.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)

import copy

# Maximum gradient norm passed to clip_grad_norm_ on every step.
GRAD_CLIP = 1.
NUM_EPOCHS = 2

# Validation losses seen so far, used to detect a new best checkpoint.
val_losses = []
for epoch in range(NUM_EPOCHS):
    model.train()
    it = iter(train_iter)
    hidden = model.init_hidden(BATCH_SIZE)
    for i, batch in enumerate(it):
        data, target = batch.text, batch.target
        if USE_CUDA:
            data, target = data.cuda(), target.cuda()
        # Carry the hidden state across batches but cut the graph, so
        # back-propagation stops at the batch boundary.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        output, hidden = model(data, hidden)

        loss = loss_fn(output.view(-1, VOCAB_SIZE), target.view(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        if i % 1000 == 0:
            print("epoch", epoch, "iter", i, "loss", loss.item())

        if i % 10000 == 0:
            val_loss = evaluate(model, val_iter)

            if not val_losses or val_loss < min(val_losses):
                print("best model, val loss: ", val_loss)
                torch.save(model.state_dict(), "lm-best.th")
            else:
                # No improvement: decay the learning rate. The optimizer must
                # NOT be re-created here -- a fresh Adam(lr=learning_rate)
                # would reset the rate to its initial value, drop Adam's
                # moment estimates, and leave the scheduler attached to an
                # optimizer that is no longer used.
                scheduler.step()
            val_losses.append(val_loss)

# Reload the best checkpoint and sample 100 words from it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

best_model = RNNModel("LSTM", VOCAB_SIZE, EMBEDDING_SIZE, EMBEDDING_SIZE, 2, dropout=0.5)
if USE_CUDA:
    best_model = best_model.cuda()
# map_location keeps loading robust if the checkpoint was written on a
# different device than the one available now.
best_model.load_state_dict(torch.load("lm-best.th", map_location=device))
# Disable dropout: sampling should use the deterministic network.
best_model.eval()

# No gradients are needed for generation, so the initial state does not track
# them and the loop runs under no_grad (otherwise the autograd graph would
# keep growing across all 100 steps).
hidden = best_model.init_hidden(1, requires_grad=False)

# Start from a random word id; shape (1, 1) is a single step of a single sequence.
input = torch.randint(VOCAB_SIZE, (1, 1), dtype=torch.long).to(device)
words = []
with torch.no_grad():
    for i in range(100):
        output, hidden = best_model(input, hidden)
        # Exponentiate the scores into unnormalized weights; multinomial
        # normalizes them internally when sampling.
        word_weights = output.squeeze().exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        # Feed the sampled word back in as the next input.
        input.fill_(word_idx)
        word = TEXT.vocab.itos[word_idx.item()]
        words.append(word)
print(" ".join(words))